{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G3MMAcssHTML"
      },
      "source": [
        "<link rel=\"stylesheet\" href=\"/site-assets/css/style.css\">\n",
        "<link rel=\"stylesheet\" href=\"/site-assets/css/gemma.css\">\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HiJG9Do4_-sm"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "_fEE8rM9BUfS"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "u71STQRgnQ3a"
      },
      "source": [
        "# Fine-tune PaliGemma with JAX and Flax\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "<td>\n",
        "<a target=\"_blank\" href=\"https://ai.google.dev/gemma/docs/paligemma/fine-tuning-paligemma\"><img src=\"https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png\" height=\"32\" width=\"32\" />View on ai.google.dev</a>\n",
        "</td>\n",
        "<td>\n",
        "<a target=\"_blank\" href=\"https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "</td>\n",
        "<td>\n",
        "<a target=\"_blank\" href=\"https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "</td>\n",
        "</table>\n",
        "\n",
        "This notebook shows how to fine-tune [PaliGemma](https://ai.google.dev/gemma/docs/paligemma) on a vision-language task with [JAX](https://jax.readthedocs.io/en/latest/index.html). *Fine-tuning* can improve your model's performance on specific tasks, or help the model follow specific output requirements, when instructions alone aren't sufficient and you have a set of examples that demonstrate the outputs you want. Pretrained PaliGemma checkpoints, like the one used in this notebook, are designed to be fine-tuned on a downstream task before use.\n",
        "\n",
        "### What's in this notebook\n",
        "\n",
        "This notebook uses the model reference implementation from [`big_vision`](https://github.com/google-research/big_vision)\n",
        "and shows how to:\n",
        "\n",
        " * Install dependencies, and download the PaliGemma model checkpoint and training data\n",
        " * Load the model onto GPU devices\n",
        " * Prepare the model's inputs for training and inference\n",
        " * Fine-tune the model\n",
        " * Inspect the output\n",
        "\n",
        "The training data for this notebook consists of 90 pairs of images and long captions describing them. To make it runnable on a T4 Colab runtime, you'll fine-tune only the attention layers of the language model and freeze the other parameters.\n",
        "\n",
        "This example is for learning purposes only. In a real use case, the amount of data, trainable parameters, training steps and hyperparameters, and the results you obtain could be significantly different.\n",
        "\n",
        "### Before you begin\n",
        "\n",
        "Before going through this notebook, you should be familiar with Python code, as well as how large language models (LLMs) are trained. You don't need to be familiar with JAX, but basic knowledge about JAX (or similar technologies such as Keras) is helpful when reading through the example code."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6U0QUFveqSP2"
      },
      "source": [
        "## Setup\n",
        "\n",
        "The following sections explain the preliminary steps for getting a notebook to use a PaliGemma model, including model access, getting an API key, and configuring the notebook runtime."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qRi1rF4MWlQi"
      },
      "source": [
        "### Get access to PaliGemma\n",
        "\n",
        "Before using PaliGemma for the first time, you must request access to the model through Kaggle by completing the following steps:\n",
        "\n",
        "1. Log in to [Kaggle](https://www.kaggle.com), or create a new Kaggle account if you don't already have one.\n",
        "1. Go to the [PaliGemma model card](https://www.kaggle.com/models/google/paligemma-2) and click **Request Access**.\n",
        "1. Complete the consent form and accept the terms and conditions."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "azmRZvgGyhAb"
      },
      "source": [
        "### Configure your API key\n",
        "\n",
        "To use PaliGemma, you must provide your Kaggle username and a Kaggle API key.\n",
        "\n",
        "To generate a Kaggle API key, open your [**Settings** page in Kaggle](https://www.kaggle.com/settings) and click **Create New Token**. This triggers the download of a `kaggle.json` file containing your API credentials.\n",
        "\n",
        "Then, in Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`.\n"
      ]
    },
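    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If you're not running in Colab, one option is to read the credentials from the `kaggle.json` file you downloaded and export them as the same environment variables. The sketch below assumes the file has the `username` and `key` fields that Kaggle writes. `load_kaggle_credentials` is a hypothetical helper. It is not part of `kagglehub` or of the required steps.\n",
        "\n",
        "```python\n",
        "import json\n",
        "import os\n",
        "from pathlib import Path\n",
        "\n",
        "\n",
        "def load_kaggle_credentials(path=Path.home() / '.kaggle' / 'kaggle.json'):\n",
        "  # Hypothetical helper: export kaggle.json credentials as env vars.\n",
        "  # Values that are already set (for example, from Colab Secrets) are kept.\n",
        "  with open(path) as f:\n",
        "    creds = json.load(f)\n",
        "  os.environ.setdefault('KAGGLE_USERNAME', creds['username'])\n",
        "  os.environ.setdefault('KAGGLE_KEY', creds['key'])\n",
        "  return os.environ['KAGGLE_USERNAME']\n",
        "```\n",
        "\n",
        "Using `setdefault` means that credentials you've already exported take precedence over the file."
      ]
    },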
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Kp6XQ2hQB8lv"
      },
      "source": [
        "### Select the runtime\n",
        "\n",
        "To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the PaliGemma model. In this case, you can use a T4 GPU:\n",
        "\n",
        "1. In the upper-right of the Colab window, click the **▾ (Additional connection options)** dropdown menu.\n",
        "1. Select **Change runtime type**.\n",
        "1. Under **Hardware accelerator**, select **T4 GPU**."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qOJ3BeYFVrOX"
      },
      "source": [
        "### Set environment variables\n",
        "\n",
        "Set the environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zGLIp1Cx3_CX"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "\n",
        "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
        "# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json\n",
        "\n",
        "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
        "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')\n",
        "\n",
        "# The T4 runtime is tight on memory to finetune this model. Preallocate\n",
        "# all memory ahead of time to avoid out-of-memory due to fragmentation.\n",
        "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.0\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rCd__uzW_eK-"
      },
      "source": [
        "### Fetch the `big_vision` repository and install related dependencies\n",
        "\n",
        "Download the `big_vision` repository to your Colab notebook from GitHub and install dependencies related to `big_vision` by running the following code."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DfxKb3F839Ks"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "\n",
        "# Colab runtimes with remote TPUs are not supported.\n",
        "if \"COLAB_TPU_ADDR\" in os.environ:\n",
        "  raise RuntimeError(\"It seems you are using Colab with remote TPUs, which is not supported.\")\n",
        "\n",
        "# Fetch the big_vision repository if it hasn't been cloned yet and install\n",
        "# dependencies needed for this notebook.\n",
        "if not os.path.exists(\"big_vision_repo\"):\n",
        "  !git clone --quiet --branch=main --depth=1 \\\n",
        "     https://github.com/google-research/big_vision big_vision_repo\n",
        "\n",
        "# Append big_vision code to python import path\n",
        "if \"big_vision_repo\" not in sys.path:\n",
        "  sys.path.append(\"big_vision_repo\")\n",
        "\n",
        "# Install missing dependencies. Assume jax~=0.4.25 with GPU available.\n",
        "!pip3 install -q \"overrides\" \"ml_collections\" \"einops~=0.7\" \"sentencepiece\"\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zDoq0O77GF30"
      },
      "source": [
        "### Import JAX and other dependencies\n",
        "\n",
        "Import JAX and other dependencies required for PaliGemma, like TensorFlow and NumPy."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dTfe2k8J4Bw0"
      },
      "outputs": [],
      "source": [
        "import base64\n",
        "import functools\n",
        "import html\n",
        "import io\n",
        "import os\n",
        "import warnings\n",
        "\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "import ml_collections\n",
        "\n",
        "import tensorflow as tf\n",
        "import sentencepiece\n",
        "\n",
        "from IPython.core.display import display, HTML\n",
        "from PIL import Image\n",
        "\n",
        "# Import model definition from big_vision\n",
        "from big_vision.models.proj.paligemma import paligemma\n",
        "from big_vision.trainers.proj.paligemma import predict_fns\n",
        "\n",
        "# Import big vision utilities\n",
        "import big_vision.datasets.jsonl\n",
        "import big_vision.utils\n",
        "import big_vision.sharding\n",
        "\n",
        "# Don't let TF use the GPU or TPUs\n",
        "tf.config.set_visible_devices([], \"GPU\")\n",
        "tf.config.set_visible_devices([], \"TPU\")\n",
        "\n",
        "backend = jax.extend.backend.get_backend()\n",
        "print(f\"JAX version:  {jax.__version__}\")\n",
        "print(f\"JAX platform: {backend.platform}\")\n",
        "print(f\"JAX devices:  {jax.device_count()}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b9kSadtIhjlX"
      },
      "source": [
        "## Download and configure the model\n",
        "\n",
        "In this step, you'll download the model checkpoint and configure it so that you can fine-tune it later on. This step shows you how to move model parameters into GPU memory, which is useful for fine-tuning models on devices with limited resources."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7tvcc0oQHl4v"
      },
      "source": [
        "### Download the model checkpoint\n",
        "\n",
        "PaliGemma includes several model variations. For this tutorial, you'll use the pretrained PaliGemma 2 3B model with 224x224 input resolution, in [JAX/Flax format](https://www.kaggle.com/models/google/paligemma-2/jax/paligemma2-3b-pt-224).\n",
        "\n",
        "Download the model checkpoint from Kaggle by running the following code. The same cell also downloads the tokenizer and the training dataset. This process takes several minutes to complete."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gQNOTfF24AV4"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import kagglehub\n",
        "\n",
        "# Use these for PaliGemma-2 3B 224px²\n",
        "LLM_VARIANT = \"gemma2_2b\"\n",
        "MODEL_PATH = \"./paligemma2-3b-pt-224.b16.npz\"\n",
        "KAGGLE_HANDLE = \"google/paligemma-2/jax/paligemma2-3b-pt-224\"  # Path to fetch from Kaggle.\n",
        "\n",
        "# Use these for PaliGemma 1:\n",
        "# LLM_VARIANT = \"gemma_2b\"\n",
        "# MODEL_PATH = \"./paligemma-3b-pt-224.f16.npz\"\n",
        "# KAGGLE_HANDLE = \"google/paligemma/jax/paligemma-3b-pt-224\"\n",
        "\n",
        "if not os.path.exists(MODEL_PATH):\n",
        "  print(\"Downloading the checkpoint from Kaggle. This could take a few minutes...\")\n",
        "  MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)\n",
        "  print(f\"Model path: {MODEL_PATH}\")\n",
        "\n",
        "TOKENIZER_PATH = \"./paligemma_tokenizer.model\"\n",
        "if not os.path.exists(TOKENIZER_PATH):\n",
        "  print(\"Downloading the model tokenizer...\")\n",
        "  !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}\n",
        "  print(f\"Tokenizer path: {TOKENIZER_PATH}\")\n",
        "\n",
        "DATA_DIR = \"./longcap100\"\n",
        "if not os.path.exists(DATA_DIR):\n",
        "  print(\"Downloading the dataset...\")\n",
        "  !gsutil -m -q cp -n -r gs://longcap100/ .\n",
        "  print(f\"Data path: {DATA_DIR}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rv7w-cGuLj5o"
      },
      "source": [
        "### Configure the model\n",
        "\n",
        "Next, configure the model that you're going to use.\n",
        "\n",
        "For this notebook, the model has to fit on a single T4 GPU. With that little memory, you have to be deliberate about how the model is configured.\n",
        "\n",
        "If you fine-tune every parameter, training won't fit in the notebook environment. Instead, you'll freeze most of the parameters and fine-tune only a small subset (here, the attention layers of the language model). A parameter is *frozen* when it keeps its pretrained value during training: it's still used to compute the model's outputs, but it isn't updated.\n",
        "\n",
        "To configure your model, you need to:\n",
        "\n",
        "* Define `model_config` as a [`FrozenConfigDict`](https://github.com/google/ml_collections/tree/master#frozenconfigdict), an immutable configuration that specifies the language model and image encoder variants\n",
        "* Initialize an instance of the PaliGemma `Model` class using `model_config`\n",
        "* Load the SentencePiece tokenizer and the model parameters into host RAM\n",
        "* Define a `decode` function to sample outputs from the model\n",
        "\n",
        "The code in this cell takes about a minute to run."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1aghcULcEdtv"
      },
      "outputs": [],
      "source": [
        "# Define model\n",
        "\n",
        "# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property. Set it to 0.0\n",
        "# for better transfer results.\n",
        "model_config = ml_collections.FrozenConfigDict({\n",
        "    \"llm\": {\"vocab_size\": 257_152, \"variant\": LLM_VARIANT, \"final_logits_softcap\": 0.0},\n",
        "    \"img\": {\"variant\": \"So400m/14\", \"pool_type\": \"none\", \"scan\": True, \"dtype_mm\": \"float16\"}\n",
        "})\n",
        "model = paligemma.Model(**model_config)\n",
        "tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)\n",
        "\n",
        "# Load params - this can take up to 1 minute in T4 colabs.\n",
        "params = paligemma.load(None, MODEL_PATH, model_config)\n",
        "\n",
        "# Define `decode` function to sample outputs from the model.\n",
        "decode_fn = predict_fns.get_all(model)['decode']\n",
        "decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uidBwmb8LwZ5"
      },
      "source": [
        "### Move model parameters into GPU/TPU memory\n",
        "\n",
        "Now you need to move the model parameters into GPU/TPU memory. First, shard the parameters across the available GPUs, then load them. Here, you'll load the parameters one at a time. This is slower than loading them all at once, but loading them all at once requires more RAM than this notebook has available.\n",
        "\n",
        "Finally, print out all of the parameters to see which type each one is cast to. Frozen parameters are kept as `float16`, while the trainable parameters are cast to `float32`. When you inspect the list, you'll see that most of the parameters are frozen and stored as `float16`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RWOdf_fw2SAO"
      },
      "outputs": [],
      "source": [
        "# Create a pytree mask of the trainable params.\n",
        "def is_trainable_param(name, param):  # pylint: disable=unused-argument\n",
        "  if name.startswith(\"llm/layers/attn/\"):  return True\n",
        "  if name.startswith(\"llm/\"):              return False\n",
        "  if name.startswith(\"img/\"):              return False\n",
        "  raise ValueError(f\"Unexpected param name {name}\")\n",
        "trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)\n",
        "\n",
        "# If more than one device is available (e.g. multiple GPUs) the parameters can\n",
        "# be sharded across them to reduce HBM usage per device.\n",
        "mesh = jax.sharding.Mesh(jax.devices(), (\"data\",))\n",
        "\n",
        "data_sharding = jax.sharding.NamedSharding(\n",
        "    mesh, jax.sharding.PartitionSpec(\"data\"))\n",
        "\n",
        "params_sharding = big_vision.sharding.infer_sharding(\n",
        "    params, strategy=[('.*', 'fsdp(axis=\"data\")')], mesh=mesh)\n",
        "\n",
        "# Expected: XLA can't reuse some donated buffers here, so silence the warning.\n",
        "warnings.filterwarnings(\n",
        "    \"ignore\", message=\"Some donated buffers were not usable\")\n",
        "\n",
        "@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))\n",
        "def maybe_cast_to_f32(params, trainable):\n",
        "  # Trainable params become float32; frozen ones become float16 (rather than\n",
        "  # bfloat16, since some GPUs don't support bf16).\n",
        "  return jax.tree.map(lambda p, m: p.astype(jnp.float32)\n",
        "                      if m else p.astype(jnp.float16),\n",
        "                      params, trainable)\n",
        "\n",
        "# Loading all params at once, albeit much faster and more succinct,\n",
        "# requires more RAM than the T4 Colab runtimes have by default.\n",
        "# Instead, do it param by param.\n",
        "params, treedef = jax.tree.flatten(params)\n",
        "sharding_leaves = jax.tree.leaves(params_sharding)\n",
        "trainable_leaves = jax.tree.leaves(trainable_mask)\n",
        "for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):\n",
        "  params[idx] = big_vision.utils.reshard(params[idx], sharding)\n",
        "  params[idx] = maybe_cast_to_f32(params[idx], trainable)\n",
        "  params[idx].block_until_ready()\n",
        "params = jax.tree.unflatten(treedef, params)\n",
        "\n",
        "# Print params to show what the model is made of.\n",
        "def parameter_overview(params):\n",
        "  for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:\n",
        "    print(f\"{path:80s} {str(arr.shape):22s} {arr.dtype}\")\n",
        "\n",
        "print(\" == Model params == \")\n",
        "parameter_overview(params)"
      ]
    },
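    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The masking and casting policy above can be tried on a toy parameter tree without JAX or a checkpoint. The sketch below uses nested dicts of NumPy arrays with made-up names and shapes. A hypothetical `flatten_with_names` helper stands in for `big_vision.utils.tree_flatten_with_names`. It applies the same rule as `is_trainable_param`: parameters under `llm/layers/attn/` are trainable and cast to `float32`, and everything else is frozen and stored as `float16` (2 bytes per value instead of 4).\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def flatten_with_names(tree, prefix=''):\n",
        "  # Yield ('a/b/c', leaf) pairs from a nested dict of arrays.\n",
        "  for key, value in tree.items():\n",
        "    name = f'{prefix}{key}'\n",
        "    if isinstance(value, dict):\n",
        "      yield from flatten_with_names(value, name + '/')\n",
        "    else:\n",
        "      yield name, value\n",
        "\n",
        "\n",
        "def is_trainable(name):\n",
        "  # Same rule as the notebook: only the LLM attention layers are trained.\n",
        "  return name.startswith('llm/layers/attn/')\n",
        "\n",
        "\n",
        "# Hypothetical toy tree; the real PaliGemma tree is much larger.\n",
        "toy_params = {\n",
        "    'img': {'embedding': np.ones((4, 8))},\n",
        "    'llm': {\n",
        "        'embedder': np.ones((16, 8)),\n",
        "        'layers': {\n",
        "            'attn': {'q': np.ones((8, 8))},\n",
        "            'mlp': {'w': np.ones((8, 32))},\n",
        "        },\n",
        "    },\n",
        "}\n",
        "\n",
        "cast = {}\n",
        "for name, arr in flatten_with_names(toy_params):\n",
        "  dtype = np.float32 if is_trainable(name) else np.float16\n",
        "  cast[name] = arr.astype(dtype)\n",
        "  print(f'{name:25s} {str(arr.shape):10s} {cast[name].dtype}')\n",
        "\n",
        "trainable = sum(a.size for n, a in cast.items() if is_trainable(n))\n",
        "total = sum(a.size for a in cast.values())\n",
        "print(f'trainable: {trainable} of {total} values ({trainable / total:.1%})')\n",
        "```"
      ]
    },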
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iD_9XXQkn1Mv"
      },
      "source": [
        "## Prepare to tune the model\n",
        "\n",
        "Now that your model is configured, you can tune it. In this step, you'll create your model's inputs as well as the training and validation iterators, view the training examples, and define the training and validation loops."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "83ZcnbddJKdx"
      },
      "source": [
        "### Create model inputs\n",
        "\n",
        "The model checkpoint you're using was trained on images of various aspect ratios resized to 224x224 pixels, paired with tokenized text.\n",
        "\n",
        "The code below defines three functions that you'll use in the next step to create the model's inputs:\n",
        "\n",
        "* **`preprocess_image`:** Normalizes the image data. It converts greyscale images (which have no channel axis) to three-channel RGB, drops any alpha channel, resizes the image to the input size the model expects (224x224 pixels), and rescales pixel values from [0, 255] to [-1, 1].\n",
        "* **`preprocess_tokens`:** Tokenizes the prefix and, if given, the suffix, and builds masks that record whether each token belongs to the prefix or the suffix (`mask_ar`), whether it counts toward the loss (`mask_loss`), and whether it's a real token or padding (`mask_input`). It also pads or truncates the sequence to `seqlen`. These masks are used later, in the training step and the evaluation loop.\n",
        "* **`postprocess_tokens`:** Removes the end-of-sequence (EOS) token and any tokens after it, and returns the decoded text.\n"
      ]
    },
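    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The last step of `preprocess_image` rescales pixel values with `x / 127.5 - 1.0`, and the code that displays training examples inverts it with `(x + 1) / 2 * 255`. The NumPy check below runs both directions, plus the greyscale-to-RGB stacking, on a made-up 2x2 image.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Made-up 2x2 greyscale image with no channel axis, values in [0, 255].\n",
        "grey = np.array([[0, 64], [191, 255]], dtype=np.uint8)\n",
        "\n",
        "# Greyscale -> three identical channels, as preprocess_image does for 2D input.\n",
        "rgb = np.stack((grey,) * 3, axis=-1)\n",
        "print(rgb.shape)  # (2, 2, 3)\n",
        "\n",
        "# [0, 255] -> [-1, 1]: 0 maps to -1.0 and 255 maps to 1.0.\n",
        "normalized = rgb / 127.5 - 1.0\n",
        "print(normalized[..., 0])\n",
        "\n",
        "# [-1, 1] -> [0, 255] for display. np.rint guards against float error\n",
        "# turning a value like 63.9999... into 63 when cast to uint8.\n",
        "restored = np.rint((normalized + 1) / 2 * 255).astype(np.uint8)\n",
        "print(np.array_equal(restored[..., 0], grey))  # True\n",
        "```\n",
        "\n",
        "The notebook's display code casts without rounding, which at worst shifts a pixel by one intensity level. That's harmless for a preview."
      ]
    },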
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8SRW0NuU4UcW"
      },
      "outputs": [],
      "source": [
        "def preprocess_image(image, size=224):\n",
        "  # Model has been trained to handle images of different aspect ratios\n",
        "  # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize\n",
        "  # options are helpful to improve quality in some tasks.\n",
        "  image = np.asarray(image)\n",
        "  if image.ndim == 2:  # Greyscale image without a channel axis: make it RGB.\n",
        "    image = np.stack((image,)*3, axis=-1)\n",
        "  image = image[..., :3]  # Remove alpha layer.\n",
        "  assert image.shape[-1] == 3\n",
        "\n",
        "  image = tf.constant(image)\n",
        "  image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)\n",
        "  return image.numpy() / 127.5 - 1.0  # [0, 255]->[-1,1]\n",
        "\n",
        "def preprocess_tokens(prefix, suffix=None, seqlen=None):\n",
        "  # Model has been trained to handle tokenized text composed of a prefix with\n",
        "  # full attention and a suffix with causal attention.\n",
        "  separator = \"\\n\"\n",
        "  tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)\n",
        "  mask_ar = [0] * len(tokens)    # 0 to use full attention for prefix.\n",
        "  mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.\n",
        "\n",
        "  if suffix:\n",
        "    suffix = tokenizer.encode(suffix, add_eos=True)\n",
        "    tokens += suffix\n",
        "    mask_ar += [1] * len(suffix)    # 1 to use causal attention for suffix.\n",
        "    mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.\n",
        "\n",
        "  mask_input = [1] * len(tokens)    # 1 if it's a token, 0 if padding.\n",
        "  if seqlen:\n",
        "    padding = [0] * max(0, seqlen - len(tokens))\n",
        "    tokens = tokens[:seqlen] + padding\n",
        "    mask_ar = mask_ar[:seqlen] + padding\n",
        "    mask_loss = mask_loss[:seqlen] + padding\n",
        "    mask_input = mask_input[:seqlen] + padding\n",
        "\n",
        "  return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))\n",
        "\n",
        "def postprocess_tokens(tokens):\n",
        "  tokens = tokens.tolist()  # np.array to list[int]\n",
        "  try:  # Remove tokens at and after EOS if any.\n",
        "    eos_pos = tokens.index(tokenizer.eos_id())\n",
        "    tokens = tokens[:eos_pos]\n",
        "  except ValueError:\n",
        "    pass\n",
        "  return tokenizer.decode(tokens)\n"
      ]
    },
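    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see exactly what `preprocess_tokens` returns without loading the tokenizer, the sketch below repeats its masking logic on already-tokenized input. `build_masks` is a hypothetical stand-in, and the integer IDs are made up; they are not real Gemma token IDs. Only the 0/1 conventions for the masks come from the notebook.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def build_masks(prefix_ids, suffix_ids=None, seqlen=None):\n",
        "  # Same masking logic as preprocess_tokens, but on lists of token IDs.\n",
        "  tokens = list(prefix_ids)\n",
        "  mask_ar = [0] * len(tokens)    # Prefix: full attention.\n",
        "  mask_loss = [0] * len(tokens)  # Prefix: not part of the loss.\n",
        "  if suffix_ids:\n",
        "    tokens += list(suffix_ids)\n",
        "    mask_ar += [1] * len(suffix_ids)    # Suffix: causal attention.\n",
        "    mask_loss += [1] * len(suffix_ids)  # Suffix: part of the loss.\n",
        "  mask_input = [1] * len(tokens)        # 1 for real tokens, 0 for padding.\n",
        "  if seqlen:\n",
        "    padding = [0] * max(0, seqlen - len(tokens))\n",
        "    tokens = tokens[:seqlen] + padding\n",
        "    mask_ar = mask_ar[:seqlen] + padding\n",
        "    mask_loss = mask_loss[:seqlen] + padding\n",
        "    mask_input = mask_input[:seqlen] + padding\n",
        "  return tuple(np.array(x) for x in (tokens, mask_ar, mask_loss, mask_input))\n",
        "\n",
        "\n",
        "# Made-up IDs: 900 = BOS, 10 11 = prefix text, 99 = separator,\n",
        "# 20 21 22 = caption text, 901 = EOS.\n",
        "tokens, mask_ar, mask_loss, mask_input = build_masks(\n",
        "    [900, 10, 11, 99], [20, 21, 22, 901], seqlen=10)\n",
        "print(tokens.tolist())      # [900, 10, 11, 99, 20, 21, 22, 901, 0, 0]\n",
        "print(mask_ar.tolist())     # [0, 0, 0, 0, 1, 1, 1, 1, 0, 0]\n",
        "print(mask_loss.tolist())   # [0, 0, 0, 0, 1, 1, 1, 1, 0, 0]\n",
        "print(mask_input.tolist())  # [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]\n",
        "```\n",
        "\n",
        "For validation examples, the notebook calls `preprocess_tokens` without a suffix, so `mask_ar` and `mask_loss` contain only zeros and the model has to generate the caption itself."
      ]
    },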
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ovgWBgdHJZq3"
      },
      "source": [
        "### Create the training and validation iterators\n",
        "\n",
        "Create two iterators:\n",
        "\n",
        "*   A **training iterator** that shuffles the training examples and repeats them indefinitely, yielding one example at a time so that you can train for as many steps as you want\n",
        "    *   Each example is preprocessed (image normalization and tokenization) as it's read\n",
        "*   A **validation iterator** that makes a single, ordered pass over the validation dataset so that you can generate a caption for each validation image and see how well the tuned model performs"
      ]
    },
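    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The difference between the two iterators can be sketched with plain Python generators over a hypothetical list of examples. `train_iterator` and `validation_iterator` are illustrative names. The first reshuffles and repeats forever. For this 90-example dataset, that behaves like the `shuffle(1_000).repeat()` pipeline in the next cell, because the shuffle buffer is larger than the dataset. The second makes one ordered pass and stops.\n",
        "\n",
        "```python\n",
        "import itertools\n",
        "import random\n",
        "\n",
        "\n",
        "def train_iterator(examples, seed=0):\n",
        "  # Never-ending: reshuffle all examples at the start of every epoch.\n",
        "  if not examples:\n",
        "    raise ValueError('examples must not be empty')\n",
        "  rng = random.Random(seed)\n",
        "  while True:\n",
        "    order = list(examples)\n",
        "    rng.shuffle(order)\n",
        "    yield from order\n",
        "\n",
        "\n",
        "def validation_iterator(examples):\n",
        "  # Single ordered pass: the generator stops after the last example.\n",
        "  yield from examples\n",
        "\n",
        "\n",
        "examples = [f'example_{i}' for i in range(5)]  # Hypothetical data.\n",
        "steps = list(itertools.islice(train_iterator(examples), 12))  # 12 > 5 examples.\n",
        "print(len(steps))                            # 12\n",
        "print(list(validation_iterator(examples)))  # The 5 examples, in order.\n",
        "```\n",
        "\n",
        "Because the training iterator never ends, the training loop decides how many steps to take. The evaluation loop can simply run until the validation iterator is exhausted."
      ]
    },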
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "whzWOojGOtzi"
      },
      "outputs": [],
      "source": [
        "SEQLEN = 128\n",
        "\n",
        "train_dataset = big_vision.datasets.jsonl.DataSource(\n",
        "    os.path.join(DATA_DIR, \"data_train90.jsonl\"),\n",
        "    fopen_keys={\"image\": DATA_DIR})\n",
        "\n",
        "val_dataset = big_vision.datasets.jsonl.DataSource(\n",
        "    os.path.join(DATA_DIR, \"data_val10.jsonl\"),\n",
        "    fopen_keys={\"image\": DATA_DIR})\n",
        "\n",
        "\n",
        "def train_data_iterator():\n",
        "  \"\"\"Never ending iterator over training examples.\"\"\"\n",
        "  # Shuffle examples and repeat so one can train for many epochs.\n",
        "  dataset = train_dataset.get_tfdata().shuffle(1_000).repeat()\n",
        "  for example in dataset.as_numpy_iterator():\n",
        "    image = Image.open(io.BytesIO(example[\"image\"]))\n",
        "    image = preprocess_image(image)\n",
        "\n",
        "    prefix = \"caption en\"  # Could also be a different prefix per example.\n",
        "    suffix = example[\"suffix\"].decode().lower()\n",
        "    tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)\n",
        "\n",
        "    yield {\n",
        "        \"image\": np.asarray(image),\n",
        "        \"text\": np.asarray(tokens),\n",
        "        \"mask_ar\": np.asarray(mask_ar),\n",
        "        \"mask_loss\": np.asarray(mask_loss),\n",
        "    }\n",
        "\n",
        "\n",
        "def validation_data_iterator():\n",
        "  \"\"\"Single iterator over validation examples.\"\"\"\n",
        "  for example in val_dataset.get_tfdata(ordered=True).as_numpy_iterator():\n",
        "    image = Image.open(io.BytesIO(example[\"image\"]))\n",
        "    image = preprocess_image(image)\n",
        "\n",
        "    prefix = \"caption en\"  # Could also be a different prefix per example.\n",
        "    tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)\n",
        "\n",
        "    yield {\n",
        "        \"image\": np.asarray(image),\n",
        "        \"text\": np.asarray(tokens),\n",
        "        \"mask_ar\": np.asarray(mask_ar),\n",
        "        \"mask_input\": np.asarray(mask_input),\n",
        "    }\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "84olaM5dCiAl"
      },
      "source": [
        "### View training examples\n",
        "\n",
        "In this notebook, the training data contains 90 images that are paired with long descriptions of what's depicted in the image.\n",
        "\n",
        "**Note:** Training datasets for practical use cases typically contain many more images, but this notebook limits the number of data points so that you can train the model in a reasonable amount of time for an example.\n",
        "\n",
        "The code below prints a random selection of images with their descriptions from the training dataset so that you can see what the images and descriptions your model is trained on look like. Each image is displayed at 128x128 pixels, with its description to the right."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BzJfb5t0nsLq"
      },
      "outputs": [],
      "source": [
        "def render_inline(image, resize=(128, 128)):\n",
        "  \"\"\"Convert image into inline html.\"\"\"\n",
        "  image = Image.fromarray(image)\n",
        "  image = image.resize(resize)  # PIL returns a new, resized image.\n",
        "  with io.BytesIO() as buffer:\n",
        "    image.save(buffer, format='jpeg')\n",
        "    image_b64 = str(base64.b64encode(buffer.getvalue()), \"utf-8\")\n",
        "    return f\"data:image/jpeg;base64,{image_b64}\"\n",
        "\n",
        "def render_example(image, caption):\n",
        "  image = ((image + 1)/2 * 255).astype(np.uint8)  # [-1,1] -> [0, 255]\n",
        "  return f\"\"\"\n",
        "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
        "        <img style=\"width:128px; height:128px;\" src=\"{render_inline(image, resize=(128, 128))}\" />\n",
        "        <p style=\"width:256px; margin:10px; font-size:small;\">{html.escape(caption)}</p>\n",
        "    </div>\n",
        "    \"\"\"\n",
        "\n",
        "html_out = \"\"\n",
        "for idx, example in zip(range(8), train_data_iterator()):\n",
        "  caption = postprocess_tokens(example[\"text\"])  # detokenize model input.\n",
        "  caption = caption[len(\"caption en\\n\"):]        # strip prefix\n",
        "  html_out += render_example(example[\"image\"], caption)\n",
        "\n",
        "print(\"Training examples\")\n",
        "display(HTML(html_out))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N2BwpXkfI8OT"
      },
      "source": [
        "### Define the training and evaluation loops\n",
        "\n",
        "Define the training loop to train the model on the provided dataset, and the evaluation loop to look at all of the examples in the validation dataset and make its predictions.\n",
        "\n",
        "#### Defining the training loop\n",
        "\n",
        "The `update_fn` function defines the training step. During the training step, the loss per example is calculated and stochastic gradient descent (SGD) is applied to the trainable parameters.\n",
        "\n",
        "Recall that the `preprocess_tokens` function returns a `mask_loss` flag for each token. You'll use `mask_loss` here to exclude prefix and padding tokens from the loss. Without it, the loss would also penalize the model on the prompt and on padding, which skews training. You also need to normalize each example's loss by its number of loss tokens, since examples have different numbers of tokens. The result is the loss per example, which is then averaged over the batch.\n",
        "\n",
        "The training step computes this loss and its gradients with `jax.value_and_grad`, and uses the gradients for the SGD update.\n",
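        "\n",
        "The NumPy sketch below computes this loss. The shapes and values are made up, and `masked_example_loss` is a hypothetical name, not the notebook's `update_fn`. The model sees `tokens[:, :-1]`, so logit position *i* is scored against `tokens[:, i + 1]`, and `mask_loss[:, 1:]` decides which of those predictions count.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def log_softmax(x, axis=-1):\n",
        "  x = x - x.max(axis=axis, keepdims=True)  # For numerical stability.\n",
        "  return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))\n",
        "\n",
        "\n",
        "def masked_example_loss(logits, tokens, mask_loss):\n",
        "  # logits: [batch, seq_len - 1, vocab], computed from tokens[:, :-1].\n",
        "  # tokens, mask_loss: [batch, seq_len].\n",
        "  logp = log_softmax(logits)\n",
        "  targets = np.eye(logits.shape[-1])[tokens[:, 1:]]  # One-hot next tokens.\n",
        "  token_logp = (logp * targets).sum(axis=-1)         # [batch, seq_len - 1]\n",
        "  mask = mask_loss[:, 1:]\n",
        "  example_loss = -(token_logp * mask).sum(axis=-1)\n",
        "  example_loss /= np.maximum(mask.sum(axis=-1), 1)   # Mean over loss tokens.\n",
        "  return example_loss.mean()                         # Mean over the batch.\n",
        "\n",
        "\n",
        "# Uniform logits over a 4-token vocabulary: each counted token costs log(4).\n",
        "tokens = np.array([[0, 1, 2, 3, 0, 0]])\n",
        "mask_loss = np.array([[0, 0, 1, 1, 0, 0]])\n",
        "logits = np.zeros((1, 5, 4))\n",
        "print(masked_example_loss(logits, tokens, mask_loss))  # ≈ 1.386 (= log 4)\n",
        "```\n",
        "\n",
        "The `np.maximum(..., 1)` guard plays the same role as `jnp.clip(..., 1)` in `update_fn`. An example with no loss tokens contributes 0 instead of causing a division by zero.\n",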
        "\n",
        "#### Defining the evaluation loop\n",
        "\n",
        "The `make_predictions` function is your evaluation loop. The evaluation loop is fairly straight forward with one notable change. If you recall from the beginning of the notebook, you only have 90 examples in your training data set. This is a very small amount of training examples, and your model ends up not having enough examples for the batch size when you run the training. This means that in the evaluation loop, you need to pad the batch by repeating examples.\n",
        "\n",
        "To make sure that the evaluation loop only returns predictions for real examples, each example carries a `_mask` flag (`True` for real examples, `False` for padding), and the loop uses it to drop the padded examples' predictions from the output."
      ]
    },
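    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The padding and masking in `make_predictions` can be sketched in isolation as follows. This is a simplified, hypothetical helper (`pad_batch`) that works on plain NumPy arrays, stacks examples with a dict comprehension instead of `jax.tree.map`, and uses a stand-in computation in place of the model's `decode` call.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def pad_batch(examples, batch_size):\n",
        "  # Assumes at least one example (make_predictions returns early otherwise).\n",
        "  # Flag real examples, then repeat the last one until the batch is full.\n",
        "  examples = [dict(ex, _mask=np.array(True)) for ex in examples]\n",
        "  while len(examples) % batch_size:\n",
        "    examples.append(dict(examples[-1], _mask=np.array(False)))\n",
        "  # Stack the list of dicts into a dict of arrays with a leading batch axis.\n",
        "  return {key: np.stack([ex[key] for ex in examples]) for key in examples[0]}\n",
        "\n",
        "# Hypothetical final batch: 2 real examples with a batch size of 4.\n",
        "batch = pad_batch([{'x': np.array(1)}, {'x': np.array(2)}], batch_size=4)\n",
        "predictions = batch['x'] * 10  # stand-in for model predictions\n",
        "real_predictions = predictions[batch['_mask']]  # drop padded examples\n",
        "```\n",
        "\n",
        "Here `batch['x']` is `[1, 2, 2, 2]`, the mask is `[True, True, False, False]`, and only the two real predictions, `[10, 20]`, survive."
      ]
    },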
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dwUV_imW3WQJ"
      },
      "outputs": [],
      "source": [
        "# The main update_fn using a simple stochastic gradient descent (SGD).\n",
        "@functools.partial(jax.jit, donate_argnums=(0,))\n",
        "def update_fn(params, batch, learning_rate):\n",
        "  imgs, txts, mask_ar = batch[\"image\"], batch[\"text\"], batch[\"mask_ar\"]\n",
        "\n",
        "  def loss_fn(params):\n",
        "    text_logits, _ = model.apply({\"params\": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)\n",
        "    logp = jax.nn.log_softmax(text_logits, axis=-1)\n",
        "\n",
        "    # The model takes as input txts[:, :-1] but the loss is defined as predicting\n",
        "    # next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens\n",
        "    # are part of the loss (e.g. prefix and padded tokens are not included).\n",
        "    mask_loss = batch[\"mask_loss\"][:, 1:]\n",
        "    targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])\n",
        "\n",
        "    # Compute the loss per example, i.e. the mean per-token negative log-likelihood.\n",
        "    # Since each example has a different number of tokens, normalize it.\n",
        "    token_pplx = jnp.sum(logp * targets, axis=-1)  # log-prob of each target token.\n",
        "    example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1)  # sum across seq_len.\n",
        "    example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1)  # normalize by num of loss tokens.\n",
        "\n",
        "    # batch_loss: mean of per example loss.\n",
        "    return jnp.mean(example_loss)\n",
        "\n",
        "  loss, grads = jax.value_and_grad(loss_fn)(params)\n",
        "\n",
        "  # Apply gradients to trainable params using SGD.\n",
        "  def apply_grad(param, gradient, trainable):\n",
        "    if not trainable: return param\n",
        "    return param - learning_rate * gradient\n",
        "\n",
        "  params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)\n",
        "\n",
        "  return params, loss\n",
        "\n",
        "# Evaluation/inference loop.\n",
        "def make_predictions(data_iterator, *, num_examples=None,\n",
        "                     batch_size=4, seqlen=SEQLEN, sampler=\"greedy\"):\n",
        "  outputs = []\n",
        "  while True:\n",
        "    # Construct a list of examples in the batch.\n",
        "    examples = []\n",
        "    try:\n",
        "      for _ in range(batch_size):\n",
        "        examples.append(next(data_iterator))\n",
        "        examples[-1][\"_mask\"] = np.array(True)  # Indicates true example.\n",
        "    except StopIteration:\n",
        "      if len(examples) == 0:\n",
        "        return outputs\n",
        "\n",
        "    # Not enough examples to complete a batch. Pad by repeating last example.\n",
        "    while len(examples) % batch_size:\n",
        "      examples.append(dict(examples[-1]))\n",
        "      examples[-1][\"_mask\"] = np.array(False)  # Indicates padding example.\n",
        "\n",
        "    # Convert list of examples into a dict of np.arrays and load onto devices.\n",
        "    batch = jax.tree.map(lambda *x: np.stack(x), *examples)\n",
        "    batch = big_vision.utils.reshard(batch, data_sharding)\n",
        "\n",
        "    # Make model predictions\n",
        "    tokens = decode({\"params\": params}, batch=batch,\n",
        "                    max_decode_len=seqlen, sampler=sampler)\n",
        "\n",
        "    # Fetch model predictions from device to host and detokenize.\n",
        "    tokens, mask = jax.device_get((tokens, batch[\"_mask\"]))\n",
        "    tokens = tokens[mask]  # remove padding examples.\n",
        "    responses = [postprocess_tokens(t) for t in tokens]\n",
        "\n",
        "    # Pair each real example's image with its response (padding sits at the end, so zip skips it).\n",
        "    for example, response in zip(examples, responses):\n",
        "      outputs.append((example[\"image\"], response))\n",
        "      if num_examples and len(outputs) >= num_examples:\n",
        "        return outputs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n9r9V1jwJvu9"
      },
      "source": [
        "## Tune the model\n",
        "\n",
        "Now that you've set everything up and taken a look at the training data, it's time to finally tune the model. The code below runs the training loop for 64 steps (512 training examples in batches of 8) and prints the learning rate (`lr` in the printed output) and the training loss for each step.\n",
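        "\n",
        "The `lr` column follows a warmup-then-cosine-decay shape. The following rough sketch reproduces that shape and the step arithmetic, assuming a linear warmup over the first 10% of steps followed by a cosine decay toward zero. `warmup_cosine` is a hypothetical stand-in; `big_vision.utils.create_learning_rate_schedule` may differ in its exact details.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "def warmup_cosine(step, total_steps, base, warmup_percent):\n",
        "  # Hypothetical schedule: linear warmup from 0 to base, then cosine decay toward 0.\n",
        "  warmup_steps = int(total_steps * warmup_percent)\n",
        "  if step < warmup_steps:\n",
        "    return base * step / warmup_steps\n",
        "  progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)\n",
        "  return base * 0.5 * (1.0 + math.cos(math.pi * progress))\n",
        "\n",
        "BATCH_SIZE, TRAIN_EXAMPLES, LEARNING_RATE = 8, 512, 0.03\n",
        "TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE  # 64 steps\n",
        "EVAL_STEPS = TRAIN_STEPS // 4  # print predictions every 16 steps\n",
        "lrs = [warmup_cosine(s, TRAIN_STEPS + 1, LEARNING_RATE, 0.10)\n",
        "       for s in range(1, TRAIN_STEPS + 1)]\n",
        "```\n",
        "\n",
        "With these values the learning rate peaks at 0.03 after the first few steps and then decays toward zero by step 64.\n",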
        "\n",
        "Every 16 steps, the code prints the model's predictions at that point in training. It uses the same set of validation images each time, so you can watch the model's descriptions improve over time.\n",
        "\n",
        "At earlier steps in the training, there are likely to be issues with the descriptions, such as repeated sentences (when the model gets stuck in a loop) or unfinished sentences. The model's predictions become steadily more accurate as training progresses. By step 64, the model's predictions should closely resemble the descriptions in the training data.\n",
        "\n",
        "This process takes around 15 minutes to complete on a T4 GPU."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "067wj_6bZAG3"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "\n",
        "# Run a short training loop with cosine learning rate schedule.\n",
        "#\n",
        "# Note: the first step can be quite slow on some machines (up to several minutes)\n",
        "# due to XLA compilation of the jax.jit'd function.\n",
        "\n",
        "BATCH_SIZE = 8\n",
        "TRAIN_EXAMPLES = 512\n",
        "LEARNING_RATE = 0.03\n",
        "\n",
        "TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE\n",
        "EVAL_STEPS = TRAIN_STEPS // 4\n",
        "\n",
        "train_data_it = train_data_iterator()\n",
        "\n",
        "sched_fn = big_vision.utils.create_learning_rate_schedule(\n",
        "    total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,\n",
        "    decay_type=\"cosine\", warmup_percent=0.10)\n",
        "\n",
        "for step in range(1, TRAIN_STEPS+1):\n",
        "  # Make list of N training examples.\n",
        "  examples = [next(train_data_it) for _ in range(BATCH_SIZE)]\n",
        "\n",
        "  # Convert list of examples into a dict of np.arrays and load onto devices.\n",
        "  batch = jax.tree.map(lambda *x: np.stack(x), *examples)\n",
        "  batch = big_vision.utils.reshard(batch, data_sharding)\n",
        "\n",
        "  # Training step and report training loss\n",
        "  learning_rate = sched_fn(step)\n",
        "  params, loss = update_fn(params, batch, learning_rate)\n",
        "\n",
        "  loss = jax.device_get(loss)\n",
        "  print(f\"step: {step:2d}/{TRAIN_STEPS:2d}   lr: {learning_rate:.5f}   loss: {loss:.4f}\")\n",
        "\n",
        "  if (step % EVAL_STEPS) == 0:\n",
        "    print(f\"Model predictions at step {step}\")\n",
        "    html_out = \"\"\n",
        "    for image, caption in make_predictions(\n",
        "        validation_data_iterator(), num_examples=4, batch_size=4):\n",
        "      html_out += render_example(image, caption)\n",
        "    display(HTML(html_out))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "glScsFLVJ52c"
      },
      "source": [
        "## Output\n",
        "\n",
        "The validation data for this notebook consists of just 10 images. In practice you would typically validate on many more examples, but for this notebook, run the following code to generate descriptions for all 10 images. After tuning the model, these descriptions should be very similar in form and content coverage to the descriptions included with the training data that you looked at earlier in this notebook."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hgUhEKjzPdMQ"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "\n",
        "# The validation data consists of 10 images in a different domain than the\n",
        "# training data.\n",
        "\n",
        "print(\"Model predictions\")\n",
        "html_out = \"\"\n",
        "for image, caption in make_predictions(validation_data_iterator(), batch_size=4):\n",
        "  html_out += render_example(image, caption)\n",
        "display(HTML(html_out))\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "fine-tuning-paligemma.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
